/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.gluten.backendsapi.omni

import org.apache.gluten.backendsapi.SparkPlanExecApi
import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.datasources.orc.OmniOrcFileFormat
import org.apache.gluten.datasources.parquet.OmniParquetFileFormat
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution._
import org.apache.gluten.expression.{ExpressionMappings, ExpressionTransformer, GenericExpressionTransformer, OmniAliasTransformer, OmniFromUnixTimeTransformer, OmniHashExpressionTransformer, OmniUnixTimestampTransformer}
import org.apache.gluten.expression.ExpressionConverter.replaceWithExpressionTransformer
import org.apache.gluten.extension.columnar.FallbackTags

import org.apache.spark.{ShuffleDependency, SparkException}
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper, OmniColumnarBatchSerializer, OmniColumnarShuffleWriter, OmniShuffleUtil}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, DateDiff, ElementAt, Expression, FromUnixTime, Generator, GetMapValue, HashExpression, Like, Md5, NamedExpression, PosExplode, PythonUDF, UnixTimestamp}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, BroadcastMode, Partitioning}
import org.apache.spark.sql.execution.{ColumnarWriteFilesExec, FileSourceScanExec, GenerateExec, OmniColumnarShuffleExchangeExec, OmniExecUtil, SparkPlan}
import org.apache.spark.sql.execution.datasources.{FileFormat, HadoopFsRelation}
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.{BinaryType, StringType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch

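/**
 * Omni backend implementation of [[SparkPlanExecApi]]. Maps Spark physical operators and catalyst
 * expressions onto their Omni columnar transformers, and signals unsupported features so that the
 * framework can fall back to vanilla Spark execution.
 */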
class OmniSparkPlanExecApi extends SparkPlanExecApi {

  /**
   * Generate FilterExecTransformer.
   *
   * @param condition
   *   : the filter condition
   * @param child
   *   : the child of FilterExec
   * @return
   *   the transformer of FilterExec
   */
  override def genFilterExecTransformer(
      condition: Expression,
      child: SparkPlan): FilterExecTransformerBase = {
    FilterExecTransformer(condition, child)
  }

  /**
   * Generate Alias transformer
   *
   * @return
   *   a transformer for alias
   */
  override def genAliasTransformer(
      substraitExprName: String,
      child: ExpressionTransformer,
      original: Expression): ExpressionTransformer =
    OmniAliasTransformer(substraitExprName, child, original)

  /**
   * Generate FromUnixTime transformer
   *
   * @return
   *   a transformer for fromUnixTime
   */
  override def genFromUnixTimeTransformer(
      substraitExprName: String,
      children: Seq[ExpressionTransformer],
      original: FromUnixTime): ExpressionTransformer =
    OmniFromUnixTimeTransformer(substraitExprName, children, original)

  override def genUnixTimestampTransformer(
      substraitExprName: String,
      children: Seq[ExpressionTransformer],
      original: UnixTimestamp): ExpressionTransformer =
    OmniUnixTimestampTransformer(substraitExprName, children, original)

  /** Generate HashAggregateExecTransformer. */
  override def genHashAggregateExecTransformer(
      requiredChildDistributionExpressions: Option[Seq[Expression]],
      groupingExpressions: Seq[NamedExpression],
      aggregateExpressions: Seq[AggregateExpression],
      aggregateAttributes: Seq[Attribute],
      initialInputBufferOffset: Int,
      resultExpressions: Seq[NamedExpression],
      child: SparkPlan): HashAggregateExecBaseTransformer = {
    OmniHashAggregateExecTransformer(
      requiredChildDistributionExpressions,
      groupingExpressions,
      aggregateExpressions,
      aggregateAttributes,
      initialInputBufferOffset,
      resultExpressions,
      child)
  }

  override def extraExpressionConverter(
      substraitExprName: String,
      expr: Expression,
      attributeSeq: Seq[Attribute]): Option[ExpressionTransformer] = expr match {
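    // Spark plans md5(stringCol) as Md5(Cast(stringCol, BinaryType)). The rewrite below drops the
    // binary cast and transforms the string child directly; the no-op Cast(input, md5.dataType)
    // only carries the result type for the generic transformer.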
    case md5: Md5 =>
      md5.child match {
        case Cast(inputExpression, outputType, _, _) if outputType == BinaryType =>
          inputExpression match {
            case AttributeReference(_, dataType, _, _) if dataType == StringType =>
              val newCast = Cast(inputExpression, md5.dataType)
              Some(
                GenericExpressionTransformer(
                  substraitExprName,
                  newCast.children.map(replaceWithExpressionTransformer(_, attributeSeq)),
                  newCast))
            case _ =>
              throw new GlutenNotSupportException(s"Not supported: $expr.")
          }
        case _ =>
          throw new GlutenNotSupportException(s"Not supported: $expr.")
      }
    case _ => None
  }

  /** Generate HashAggregateExecPullOutHelper */
  override def genHashAggregateExecPullOutHelper(
      aggregateExpressions: Seq[AggregateExpression],
      aggregateAttributes: Seq[Attribute]): HashAggregateExecPullOutBaseHelper =
    OmniHashAggregateExecPullOutBaseHelper(aggregateExpressions, aggregateAttributes)

  override def genColumnarShuffleExchange(shuffle: ShuffleExchangeExec): SparkPlan = {
    val child = shuffle.child
    val columnarConf = GlutenConfig.get
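    // Prefer Omni's row-based shuffle when it is enabled and the output schema is wider than the
    // configured column threshold; otherwise shuffle columnar batches.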
    val isRowShuffle = columnarConf.enableOmniRowShuffle &&
      shuffle.output.length > columnarConf.omniRowShuffleColumnsThreshold
    val newShuffle = OmniColumnarShuffleExchangeExec(shuffle, child, shuffle.output, isRowShuffle)
    val validationResult = newShuffle.doValidate()
    if (validationResult.ok()) {
      newShuffle
    } else {
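      // Validation failed: tag the exchange so the fallback rules keep the vanilla row-based
      // shuffle.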
      FallbackTags.add(shuffle, validationResult)
      shuffle.withNewChildren(child :: Nil)
    }
  }

  /** Generate ShuffledHashJoinExecTransformer. */
  override def genShuffledHashJoinExecTransformer(
      leftKeys: Seq[Expression],
      rightKeys: Seq[Expression],
      joinType: JoinType,
      buildSide: BuildSide,
      condition: Option[Expression],
      left: SparkPlan,
      right: SparkPlan,
      isSkewJoin: Boolean): ShuffledHashJoinExecTransformerBase = {
    OmniShuffledHashJoinExecTransformer(
      leftKeys,
      rightKeys,
      joinType,
      buildSide,
      condition,
      left,
      right,
      isSkewJoin,
      null)
  }

  /** Generate BroadcastHashJoinExecTransformer. */
  override def genBroadcastHashJoinExecTransformer(
      leftKeys: Seq[Expression],
      rightKeys: Seq[Expression],
      joinType: JoinType,
      buildSide: BuildSide,
      condition: Option[Expression],
      left: SparkPlan,
      right: SparkPlan,
      isNullAwareAntiJoin: Boolean): BroadcastHashJoinExecTransformerBase = {
    OmniBroadcastHashJoinExecTransformer(
      leftKeys,
      rightKeys,
      joinType,
      buildSide,
      condition,
      left,
      right,
      isNullAwareAntiJoin,
      null)
  }

  /** Sample offload is not supported by the Omni backend. */
  override def genSampleExecTransformer(
      lowerBound: Double,
      upperBound: Double,
      withReplacement: Boolean,
      seed: Long,
      child: SparkPlan): SampleExecTransformer =
    throw new GlutenNotSupportException("SampleExec is not supported by the Omni backend")

  /** Generate SortMergeJoinExecTransformer. */
  override def genSortMergeJoinExecTransformer(
      leftKeys: Seq[Expression],
      rightKeys: Seq[Expression],
      joinType: JoinType,
      condition: Option[Expression],
      left: SparkPlan,
      right: SparkPlan,
      isSkewJoin: Boolean,
      projectList: Seq[NamedExpression]): SortMergeJoinExecTransformerBase = {
    OmniSortMergeJoinExecTransformer(
      leftKeys,
      rightKeys,
      joinType,
      condition,
      left,
      right,
      isSkewJoin,
      projectList
    )
  }

  /** CartesianProduct offload is not supported by the Omni backend. */
  override def genCartesianProductExecTransformer(
      left: SparkPlan,
      right: SparkPlan,
      condition: Option[Expression]): CartesianProductExecTransformer =
    throw new GlutenNotSupportException(
      "CartesianProductExec is not supported by the Omni backend")

  override def genBroadcastNestedLoopJoinExecTransformer(
      left: SparkPlan,
      right: SparkPlan,
      buildSide: BuildSide,
      joinType: JoinType,
      condition: Option[Expression]): BroadcastNestedLoopJoinExecTransformer = {
    OmniBroadcastNestedLoopJoinExecTransformer(
      left,
      right,
      buildSide,
      joinType,
      condition
    )
  }

  /** Transform GetMapValue to Substrait by reusing the mapping registered for ElementAt. */
  override def genGetMapValueTransformer(
      substraitExprName: String,
      left: ExpressionTransformer,
      right: ExpressionTransformer,
      original: GetMapValue): ExpressionTransformer = GenericExpressionTransformer(
    ExpressionMappings.expressionsMap(classOf[ElementAt]),
    Seq(left, right),
    original)

  override def genHashExpressionTransformer(
      substraitExprName: String,
      exprs: Seq[ExpressionTransformer],
      original: HashExpression[_]): ExpressionTransformer = {
    OmniHashExpressionTransformer(substraitExprName, exprs, original)
  }

  /** Transform GetArrayItem to Substrait. */
  override def genGetArrayItemTransformer(
      substraitExprName: String,
      left: ExpressionTransformer,
      right: ExpressionTransformer,
      original: Expression): ExpressionTransformer = {
    GenericExpressionTransformer(substraitExprName, Seq(left, right), original)
  }

  /** posexplode is not supported by the Omni backend. */
  override def genPosExplodeTransformer(
      substraitExprName: String,
      child: ExpressionTransformer,
      original: PosExplode,
      attributeSeq: Seq[Attribute]): ExpressionTransformer =
    throw new GlutenNotSupportException("posexplode is not supported by the Omni backend")

  /**
   * Generate ShuffleDependency for ColumnarShuffleExchangeExec.
   *
   * childOutputAttributes may differ from outputAttributes; for example, childOutputAttributes
   * may include additional shuffle key columns.
   *
   * @return
   */
  // scalastyle:off argcount
  override def genShuffleDependency(
      rdd: RDD[ColumnarBatch],
      childOutputAttributes: Seq[Attribute],
      outputAttributes: Seq[Attribute],
      newPartitioning: Partitioning,
      serializer: Serializer,
      writeMetrics: Map[String, SQLMetric],
      metrics: Map[String, SQLMetric],
      isSort: Boolean): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {

    // scalastyle:on argcount
    OmniShuffleUtil.genShuffleDependency(
      rdd,
      childOutputAttributes,
      newPartitioning,
      serializer,
      writeMetrics,
      metrics,
      isSort)
  }

  /** Omni never selects sort-based shuffle, regardless of partitioning and output. */
  override def useSortBasedShuffle(partitioning: Partitioning, output: Seq[Attribute]): Boolean =
    false

  /**
   * Generate ColumnarShuffleWriter for ColumnarShuffleManager.
   *
   * @return
   */
  override def genColumnarShuffleWriter[K, V](
      parameters: GenShuffleWriterParameters[K, V]): GlutenShuffleWriterWrapper[K, V] = {

    GlutenShuffleWriterWrapper(
      new OmniColumnarShuffleWriter[K, V](
        parameters.shuffleBlockResolver,
        parameters.columnarShuffleHandle,
        parameters.mapId,
        parameters.metrics))
  }

  /**
   * Generate ColumnarBatchSerializer for ColumnarShuffleExchangeExec.
   *
   * @return
   */
  override def createColumnarBatchSerializer(
      schema: StructType,
      metrics: Map[String, SQLMetric],
      isSort: Boolean): Serializer = {
    val readBatchNumRows = metrics("avgReadBatchNumRows")
    val numOutputRows = metrics("numOutputRows")
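    // The deserializer must agree with the writer on row vs. columnar batch format, so recompute
    // the same width threshold used in genColumnarShuffleExchange.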
    val columnarConf = GlutenConfig.get
    val isRowShuffle = columnarConf.enableOmniRowShuffle &&
      schema.length > columnarConf.omniRowShuffleColumnsThreshold
    new OmniColumnarBatchSerializer(readBatchNumRows, numOutputRows, isRowShuffle)
  }

  /** Create broadcast relation for BroadcastExchangeExec */
  override def createBroadcastRelation(
      mode: BroadcastMode,
      child: SparkPlan,
      numOutputRows: SQLMetric,
      dataSize: SQLMetric): BuildSideRelation = {
    // Collect the serialized build side to the driver, then enforce Spark's broadcast size limit
    // before handing the relation to executors.
    val input = OmniExecUtil.buildSideRDD(mode, child).collect()
    val relation = OmniColumnarBuildSideRelation(mode, child.output, input.map(_.getBatches))
    dataSize.add(input.map(_.getBatches.length).sum)
    numOutputRows.add(input.map(_.getRowNum).sum)
    if (dataSize.value >= BroadcastExchangeExec.MAX_BROADCAST_TABLE_BYTES) {
      throw new SparkException(
        s"Cannot broadcast the table that is larger than 8GB: ${dataSize.value >> 30} GB")
    }
    // todo: add EmptyHashedRelation & HashedRelationWithAllNullKeys
    relation
  }

  /** Native columnar file write is not supported by the Omni backend. */
  override def createColumnarWriteFilesExec(
      child: WriteFilesExecTransformer,
      noop: SparkPlan,
      fileFormat: FileFormat,
      partitionColumns: Seq[Attribute],
      bucketSpec: Option[BucketSpec],
      options: Map[String, String],
      staticPartitions: TablePartitionSpec): ColumnarWriteFilesExec =
    throw new GlutenNotSupportException(
      "Columnar write files is not supported by the Omni backend")

  /** ColumnarArrowEvalPythonExec is not supported by the Omni backend. */
  override def createColumnarArrowEvalPythonExec(
      udfs: Seq[PythonUDF],
      resultAttrs: Seq[Attribute],
      child: SparkPlan,
      evalType: Int): SparkPlan =
    throw new GlutenNotSupportException(
      "ColumnarArrowEvalPythonExec is not supported by the Omni backend")

  override def genLikeTransformer(
      substraitExprName: String,
      left: ExpressionTransformer,
      right: ExpressionTransformer,
      original: Like): ExpressionTransformer = {
    GenericExpressionTransformer(
      substraitExprName,
      Seq(left, right),
      original)
  }

  /** datediff is not supported by the Omni backend. */
  override def genDateDiffTransformer(
      substraitExprName: String,
      endDate: ExpressionTransformer,
      startDate: ExpressionTransformer,
      original: DateDiff): ExpressionTransformer =
    throw new GlutenNotSupportException("datediff is not supported by the Omni backend")

  override def genGenerateTransformer(
      generator: Generator,
      requiredChildOutput: Seq[Attribute],
      outer: Boolean,
      generatorOutput: Seq[Attribute],
      child: SparkPlan): GenerateExecTransformerBase = {
    OmniGenerateExecTransformer(generator, requiredChildOutput, outer, generatorOutput, child)
  }

  // Assume no pre/post projection is needed around Generate for the Omni backend; return the
  // plan unchanged.
  override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = generate

  override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate

  override def maybeCollapseTakeOrderedAndProject(plan: SparkPlan): SparkPlan = {
    // This to-top-n optimization assumes exchange operators have already been placed in the
    // input plan.
    plan.transformUp {
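      // Collapse a zero-offset limit sitting directly above a sort into a single Omni top-n
      // operator; `global` is true when the child already gathers all tuples into one partition.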
      case p @ LimitExecTransformer(SortExecTransformer(sortOrder, _, child, _), 0, count) =>
        val global = child.outputPartitioning.satisfies(AllTuples)
        val topN = OmniTopNTransformer(count, sortOrder, global, child)
        if (topN.doValidate().ok()) {
          topN
        } else {
          p
        }
      case other => other
    }
  }

  override def genFileSourceScanExecTransformer(
      scanExec: FileSourceScanExec): FileSourceScanExecTransformerBase = {
    val hadoopFsRelation = scanExec.relation
    // Swap in Omni's native ORC/Parquet readers; any other file format passes through unchanged.
    val fileFormat: FileFormat = hadoopFsRelation.fileFormat match {
      case _: OrcFileFormat => new OmniOrcFileFormat()
      case _: ParquetFileFormat => new OmniParquetFileFormat()
      case other => other
    }
    val newRelation = HadoopFsRelation(
      hadoopFsRelation.location,
      hadoopFsRelation.partitionSchema,
      hadoopFsRelation.dataSchema,
      hadoopFsRelation.bucketSpec,
      fileFormat,
      hadoopFsRelation.options)(SparkSession.active)

    FileSourceScanExecTransformer(
      newRelation,
      scanExec.output,
      scanExec.requiredSchema,
      scanExec.partitionFilters,
      scanExec.optionalBucketSet,
      scanExec.optionalNumCoalescedBuckets,
      scanExec.dataFilters,
      scanExec.tableIdentifier,
      scanExec.disableBucketedScan
    )
  }

  /** PromotePrecision needs no special handling in Omni: transform the underlying cast directly. */
  override def genPromotePrecisionTransformer(
      cast: Cast,
      attributeSeq: Seq[Attribute]): ExpressionTransformer = {
    replaceWithExpressionTransformer(cast, attributeSeq)
  }
}
